{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 多层异构检索架构的实现\n",
    "Multi-layer Heterogeneous Vector Indexer (MHVI)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "#头文件在这里集中import\n",
    "import numpy as np\n",
    "import time\n",
    "import os\n",
    "import struct\n",
    "import faiss\n",
    "import subprocess\n",
    "import matplotlib.pyplot as plt\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset, index and query parameters\n",
    "written = True  # whether to write output files; set False to avoid overwriting existing files while testing\n",
    "evaluation = True  # whether to evaluate recall; False = only train the clustering and export it\n",
    "type = np.float32  # element dtype of the dataset and query set (ground truth is always np.uint32). NOTE: shadows the builtin `type`; later cells use this name\n",
    "# type = np.uint8\n",
    "\n",
    "# File paths. All vector files use the .bin layout: a 4-byte vector count, a 4-byte dimension,\n",
    "# then the raw vectors. The element size is not stored in the file; it is given by `type` above.\n",
    "# sift1M\n",
    "dataset_path = \"/home/gary/Code/DiskANN/build/data/sift/sift_base.fbin\"\n",
    "query_path = \"/home/gary/Code/DiskANN/build/data/sift/sift_query.fbin\"  # query/GT files are only read when evaluation is True, but both paths are still written to config_path\n",
    "gt_path = \"/home/gary/Code/DiskANN/build/data/sift/sift_groundtruth.fbin\"\n",
    "# # bigann100M\n",
    "# dataset_path = \"/home/gary/Code/DiskANN/build/data/bigann/bigann_learn.bbin\"\n",
    "# query_path = \"/home/gary/Code/DiskANN/build/data/bigann/bigann_query.bbin\"\n",
    "# gt_path = \"/home/gary/Code/DiskANN/build/data/bigann/bigann_gt.ibin\"\n",
    "\n",
    "# Clustering parameters\n",
    "cluster_count = 1000  # number of clusters\n",
    "train_ratio = 0.1  # fraction of the vectors used for training, e.g. 1 = train on all vectors\n",
    "\n",
    "# Output folder, e.g. \"sift_base_1000/\"\n",
    "output_folder = dataset_path.split(\"/\")[-1].split(\".\")[0]\n",
    "output_folder += \"_\" + str(cluster_count) + \"/\"\n",
    "if written:\n",
    "    # Later cells open files inside this folder for writing; create it so a fresh run does not fail.\n",
    "    os.makedirs(output_folder, exist_ok=True)\n",
    "\n",
    "# Cluster-centroid file\n",
    "cluster_center_path = output_folder + \"cluster_centers.bin\"\n",
    "# Bottom-layer offset file\n",
    "offset_list_path = output_folder + \"offset_list.bin\"\n",
    "# Bottom-layer clustered-vector file\n",
    "last_layer_path = output_folder + \"last_layer.bin\"\n",
    "\n",
    "# Query parameters\n",
    "k = 10  # number of nearest neighbours to return\n",
    "nprobe = 10  # number of clusters to search\n",
    "\n",
    "# Graph-build parameters\n",
    "R = 32  # average number of neighbours\n",
    "L = 50  # candidate list length during build\n",
    "B = 120  # target memory size (GB)\n",
    "M = 120  # memory available during build, larger is better (GB)\n",
    "\n",
    "# Config file used to pass parameters to the search program\n",
    "config_path = \"dataset_config.txt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "#检索结构，三层的索引都存在一个类中\n",
    "class MHVI_Index:\n",
    "    def __init__(self):\n",
    "        self.last_layer_index = None\n",
    "        self.middle_layer_index = None\n",
    "        self.top_layer_index = None\n",
    "\n",
    "    def __sizeof__(self):\n",
    "        pass\n",
    "\n",
    "mhvi_index = MHVI_Index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "#读取bin格式的数据集文件，数据集、查询集、GT集都可以使用这个函数读取\n",
    "def read_bin(file_path, type):\n",
    "    file_size = os.path.getsize(file_path)\n",
    "    print(\"Read data from\", file_path, \"\\nfile size:\",file_size)\n",
    "    feature_size = np.dtype(type).itemsize\n",
    "    with open(file_path,\"rb\") as fd:\n",
    "        lines = int.from_bytes(fd.read(4), byteorder='little')\n",
    "        dim = int.from_bytes(fd.read(4), byteorder='little')\n",
    "        print(\"lines:\",lines,\"dim:\",dim)\n",
    "\n",
    "        data_size = lines * dim * feature_size\n",
    "        if(data_size+8 != file_size):\n",
    "            print(\"Error! file size and argument not match!\") # 判断实际文件大小是否与参数匹配，简单的纠错机制\n",
    "            return None\n",
    "\n",
    "        binary_data = fd.read(data_size)\n",
    "        vectors = np.frombuffer(binary_data, dtype=type)\n",
    "        vectors = vectors.reshape(lines, dim)\n",
    "        print(\"Returned vector list:\",vectors.shape, vectors.dtype)\n",
    "\n",
    "        return vectors\n",
    "    \n",
    "#计算Recall，I是搜索结果，gts是ground truth，k是计算前多少个搜索结果的召回率\n",
    "def compute_recall(I, gts, k):\n",
    "    num_queries = I.shape[0]\n",
    "    recall_sum = 0\n",
    "    \n",
    "    for i in range(num_queries):\n",
    "        retrieved = set(I[i])\n",
    "        ground_truth = set(gts[i][:k])  # 取前k个真实结果\n",
    "        correct = len(retrieved.intersection(ground_truth))\n",
    "        recall_sum += correct / k\n",
    "    \n",
    "    return recall_sum / num_queries\n",
    "\n",
    "#将向量集以.bin的格式写入文件中，主要用于写入质心\n",
    "def write_bin(filename, array):\n",
    "    # 检查输入是否为numpy数组\n",
    "    if not isinstance(array, np.ndarray):\n",
    "        raise ValueError(\"Input must be a NumPy array.\")\n",
    "    \n",
    "    # 获取数组的形状（行数和列数）\n",
    "    rows, cols = array.shape\n",
    "    \n",
    "    # 打开文件准备写入\n",
    "    with open(filename, 'wb') as f:\n",
    "        # 将行数和列数写入文件的前8个字节\n",
    "        np.array(rows, dtype=np.uint32).tofile(f)\n",
    "        np.array(cols, dtype=np.uint32).tofile(f)\n",
    "        \n",
    "        # 写入数组数据\n",
    "        # 确保数据类型正确，这里假设使用float32\n",
    "        array.tofile(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Read data from /home/gary/Code/DiskANN/build/data/sift/sift_base.fbin \n",
      "file size: 512000008\n",
      "lines: 1000000 dim: 128\n",
      "Returned vector list: (1000000, 128) float32\n",
      "[[  0.  16.  35. ...  25.  23.   1.]\n",
      " [ 14.  35.  19. ...  11.  21.  33.]\n",
      " [  0.   1.   5. ...   4.  23.  10.]\n",
      " ...\n",
      " [ 30.  12.  12. ...  50.  10.   0.]\n",
      " [  0.   5.  12. ...   1.   2.  13.]\n",
      " [114.  31.   0. ...  25.  16.   0.]]\n",
      "Read data from /home/gary/Code/DiskANN/build/data/sift/sift_query.fbin \n",
      "file size: 5120008\n",
      "lines: 10000 dim: 128\n",
      "Returned vector list: (10000, 128) float32\n",
      "[[  1.   3.  11. ...  42.  48.  11.]\n",
      " [ 40.  25.  11. ...   3.  19.  13.]\n",
      " [ 28.   4.   3. ...   2.  54.  47.]\n",
      " ...\n",
      " [  0.  15.  64. ...   3.  62. 118.]\n",
      " [131.   2.   0. ...   7.   0.   0.]\n",
      " [ 23.   0.   0. ...  79.  16.   4.]]\n",
      "Read data from /home/gary/Code/DiskANN/build/data/sift/sift_groundtruth.fbin \n",
      "file size: 4000008\n",
      "lines: 10000 dim: 100\n",
      "Returned vector list: (10000, 100) uint32\n",
      "[[932085 934876 561813 ... 398306 931721 989762]\n",
      " [413247 413071 706838 ... 855176 846198 987074]\n",
      " [669835 408764 408462 ... 310475 971815 937903]\n",
      " ...\n",
      " [123855 123351 534149 ...  90175 685486 416474]\n",
      " [755327 755323 840765 ... 595134 601257 172180]\n",
      " [874343 464509 413340 ... 360985 419949 223427]]\n"
     ]
    }
   ],
   "source": [
    "# Load the base vectors; the query set and ground truth are only needed for evaluation.\n",
    "dataset = read_bin(dataset_path, type)\n",
    "print(dataset)\n",
    "\n",
    "queries, gts = None, None\n",
    "if evaluation:\n",
    "    queries = read_bin(query_path, type)\n",
    "    print(queries)\n",
    "\n",
    "    gts = read_bin(gt_path, np.uint32)  # ground-truth ids are always uint32\n",
    "    print(gts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the key=value configuration for the search program. This runs after loading the\n",
    "# dataset because vector_count, dim and feature_type are taken from it.\n",
    "with open(config_path, \"w\") as f:\n",
    "    f.writelines(\n",
    "        name + \"=\" + str(value) + \"\\n\"\n",
    "        for name, value in (\n",
    "            (\"dataset_path\", dataset_path),\n",
    "            (\"query_path\", query_path),\n",
    "            (\"gt_path\", gt_path),\n",
    "            (\"cluster_center_path\", cluster_center_path),\n",
    "            (\"offset_list_path\", offset_list_path),\n",
    "            (\"last_layer_path\", last_layer_path),\n",
    "            (\"cluster_count\", cluster_count),\n",
    "            (\"train_ratio\", train_ratio),\n",
    "            (\"k\", k),\n",
    "            (\"nprobe\", nprobe),\n",
    "            (\"R\", R),\n",
    "            (\"L\", L),\n",
    "            (\"B\", B),\n",
    "            (\"M\", M),\n",
    "            (\"vector_count\", dataset.shape[0]),\n",
    "            (\"dim\", dataset.shape[1]),\n",
    "            (\"feature_type\", dataset.dtype),\n",
    "        )\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Start train\n",
      "Train done, spend: 6.1 second\n",
      "Start add vector\n",
      "Add done, spend: 8.2 second\n",
      "Writing index\n",
      "Write done\n"
     ]
    }
   ],
   "source": [
    "# Build the bottom-layer IVF index with faiss.\n",
    "IVF_index = faiss.index_factory(dataset.shape[1], f\"IVF{cluster_count},Flat\")\n",
    "\n",
    "# Train on an evenly spaced sample of round(n * train_ratio) vectors, never fewer than\n",
    "# cluster_count (training needs at least one point per cluster). A stride of int(1 / train_ratio)\n",
    "# would not honour ratios such as 0.7 (stride 1 = all vectors). For train_ratio = 0.1 this\n",
    "# selects exactly the rows of dataset[::10].\n",
    "print(\"Start train\")\n",
    "st = time.time()\n",
    "num_vectors = dataset.shape[0]\n",
    "num_train = min(num_vectors, max(cluster_count, int(round(num_vectors * train_ratio))))\n",
    "if num_train >= num_vectors:\n",
    "    train_set = dataset\n",
    "else:\n",
    "    train_idx = np.linspace(0, num_vectors, num_train, endpoint=False).astype(np.int64)\n",
    "    train_set = dataset[train_idx]\n",
    "IVF_index.train(train_set)\n",
    "et = time.time()\n",
    "print(\"Train done, spend: %.1f second\"%(et-st))\n",
    "\n",
    "# Add every vector to the index\n",
    "print(\"Start add vector\")\n",
    "st = time.time()\n",
    "IVF_index.add(dataset)\n",
    "et = time.time()\n",
    "print(\"Add done, spend: %.1f second\"%(et-st))\n",
    "\n",
    "# Save the index (this file is not essential for the rest of the pipeline)\n",
    "dataset_name = dataset_path.split(\"/\")[-1].split(\".\")[0]  # dataset file name up to the first '.'\n",
    "if written:\n",
    "    print(\"Writing index\")\n",
    "    faiss.write_index(IVF_index, \"ivf_index_%s_%d.faiss\"%(dataset_name,cluster_count))\n",
    "    print(\"Write done\")\n",
    "else:\n",
    "    print(\"Not write index\")\n",
    "\n",
    "# Register the IVF index as the last (bottom) layer of the MHVI index\n",
    "mhvi_index.last_layer_index = IVF_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.36961999999999423\n",
      "0.5349899999999993\n",
      "0.696519999999995\n",
      "0.8349099999999847\n",
      "0.9303999999999727\n"
     ]
    }
   ],
   "source": [
    "# Bottom-layer recall test: effect of the number of probed clusters (nprobe) on recall@k.\n",
    "if evaluation:\n",
    "    for probe in [1, 2, 4, 8, 16]:  # number of clusters searched per query\n",
    "        IVF_index.nprobe = probe\n",
    "        D, I = IVF_index.search(queries, k)\n",
    "        print(\"nprobe=%d recall@%d=%.4f\" % (probe, k, compute_recall(I, gts, k)))\n",
    "    # Restore the configured nprobe; the same object is mhvi_index.last_layer_index, so\n",
    "    # otherwise later searches would silently keep the last sweep value.\n",
    "    IVF_index.nprobe = nprobe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[30.373627   3.879121   2.087912  ... 12.329671   8.747253  12.791209 ]\n",
      " [27.685484  17.75       7.5887094 ... 28.532257  40.443546  27.209677 ]\n",
      " [ 5.0958085  2.4191618  2.7724552 ...  1.4011977  3.1257486  7.904192 ]\n",
      " ...\n",
      " [63.402298  18.02299    6.5172415 ... 13.816092  21.402298  28.712643 ]\n",
      " [10.070796   8.769912   6.938053  ... 69.74336   29.38938    6.256637 ]\n",
      " [36.32558   12.139535  12.930232  ... 17.139534  37.069767  52.627907 ]]\n",
      "(1000, 128)\n",
      "float32\n",
      "cluster center has written to: sift_base_1000/cluster_centers.bin\n"
     ]
    }
   ],
   "source": [
    "# Read the cluster centroids back out of the IVF coarse quantizer.\n",
    "quantizer = IVF_index.quantizer\n",
    "centroid_shape = (cluster_count, dataset.shape[1])\n",
    "centroids = np.zeros(centroid_shape, dtype=np.float32)\n",
    "quantizer.reconstruct_n(0, cluster_count, centroids)  # fills `centroids` in place\n",
    "\n",
    "# Centroids are virtual centres, not points taken from the dataset.\n",
    "for info in (centroids, centroids.shape, centroids.dtype):\n",
    "    print(info)\n",
    "\n",
    "# Persist the centroids in the .bin layout\n",
    "if written:\n",
    "    write_bin(cluster_center_path, centroids)\n",
    "    print(\"cluster center has written to:\", cluster_center_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "#获取向量所属的聚类ID\n",
    "def get_cluster_id_for_vector(vector, quantizer):\n",
    "    # 使用 search 方法获取向量所属的聚类 ID\n",
    "    _, cluster_id = quantizer.search(np.expand_dims(vector, axis=0), 1) #第二个参数为返回个数\n",
    "    return cluster_id[0][0]\n",
    "\n",
    "#批量获取向量所属的聚类ID\n",
    "def get_cluster_id_for_vectors(vectors, quantizer):\n",
    "    # 使用 search 方法获取向量所属的聚类ID\n",
    "    _, cluster_ids = quantizer.search(vectors, 1)\n",
    "    #将cluster_ids转化为一个一维数组\n",
    "    cluster_ids = cluster_ids.flatten()\n",
    "    return cluster_ids\n",
    "\n",
    "#依次获取dateset所有向量的聚类ID，并记录在一个数组中。数组长度等于聚类数量，每个元素代表每个聚类的向量数量\n",
    "def statis_cluster(vectors, quantizer):\n",
    "    cluster_ids = get_cluster_id_for_vectors(vectors, quantizer)\n",
    "    cluster_ids_statis = [0]*cluster_count #统计每个聚类的出现次数\n",
    "\n",
    "    for id in cluster_ids:\n",
    "        cluster_ids_statis[id]+=1\n",
    "    return cluster_ids_statis, max(cluster_ids_statis), min(cluster_ids_statis)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-cluster vector counts, plus the largest and the smallest cluster size.\n",
    "statis_result, max_value, min_value = statis_cluster(dataset, quantizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAAGgCAYAAAAKKQXsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAfe0lEQVR4nO3df2yV5f3/8dfRwrHV06OonMMZVeo8mrmKM2Aq1dH6o91YZTMsbhPGMG4JWFA6snTW/uFxcaddk3XVNGJwS61Zmu4PwZEwsTVqnSnMCjSWujEXC1THsdOVcyrU00mvzx9+vb8cWrWnP67TU56P5Eq8r/s693kfLkpfXuf+4TLGGAEAAFhyTqoLAAAAZxfCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALAqqfCxaNEiuVyuUW3jxo2SJGOMQqGQAoGAMjMzVVRUpJ6enmkpHAAApKeMZAZ3dnbq1KlTzvbBgwdVXFysu+66S5JUW1ururo6Pf3007rqqqv06KOPqri4WIcOHZLH4xnXe4yMjOjf//63PB6PXC5XMuUBAIAUMcZocHBQgUBA55zzJWsbZhI2b95svvrVr5qRkREzMjJi/H6/qampcfZ//PHHxuv1mieffHLcx+zr6zOSaDQajUajpWHr6+v70t/1Sa18nG54eFh//OMftWXLFrlcLr3zzjuKRCIqKSlxxrjdbhUWFqqjo0Pr168f8zjxeFzxeNzZNv/vIbt9fX3Kzs6eaHkAAMCiWCymnJyccX3TMeHw8dxzz+n48eO65557JEmRSESS5PP5Esb5fD4dOXLkc49TXV2tRx55ZFR/dnY24QMAgDQznlMmJny1yx/+8AetWLFCgUDgC9/UGPOFhVRWVioajTqtr69voiUBAIA0MKGVjyNHjujFF1/U9u3bnT6/3y/p0xWQBQsWOP39/f2jVkNO53a75Xa7J1IGAABIQxNa+WhsbNT8+fNVWlrq9OXm5srv96utrc3pGx4eVnt7uwoKCiZfKQAAmBWSXvkYGRlRY2Oj1q1bp4yM//9yl8ul8vJyhcNhBYNBBYNBhcNhZWVlafXq1VNaNAAASF9Jh48XX3xRR48e1b333jtqX0VFhYaGhlRWVqaBgQHl5+ertbV13Pf4AAAAs5/LfHZt6wwRi8Xk9XoVjUa52gUAgDSRzO9vnu0CAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsGpCD5ZDelr04K5RfYdrSscYCQDA9GHlAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVRmpLgDTZ9GDu1JdAgAAo7DyAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsSjp8vPfee/rxj3+siy++WFlZWfrGN76hffv2OfuNMQqFQgoEAsrMzFRRUZF6enqmtGgAAJC+kgofAwMDuummmzRnzhw9//zzeuutt/Tb3/5WF154oTOmtrZWdXV1amhoUGdnp/x+v4qLizU4ODjVtQMAgDSUkczg3/zmN8rJyVFjY6PTt2jRIue/jTGqr69XVVWVVq1aJUlqamqSz+dTc3Oz1q9fP+qY8Xhc8Xjc2Y7FYsl+BgAAkEaSWvnYuXOnli5dqrvuukvz58/X9ddfr6eeesrZ39vbq0gkopKSEqfP7XarsLBQHR0dYx6zurpaXq/XaTk5ORP8KAAAIB0kFT7eeecdbd26VcFgUC+88II2bNigBx54QM8884wkKRKJSJJ8Pl/C63w+n7PvTJWVlYpGo07r
6+ubyOcAAABpIqmvXUZGRrR06VKFw2FJ0vXXX6+enh5t3bpVP/nJT5xxLpcr4XXGmFF9n3G73XK73cnWDQAA0lRSKx8LFizQNddck9D3ta99TUePHpUk+f1+SRq1ytHf3z9qNQQAAJydkgofN910kw4dOpTQ989//lOXX365JCk3N1d+v19tbW3O/uHhYbW3t6ugoGAKygUAAOkuqa9dfv7zn6ugoEDhcFg/+MEP9Prrr2vbtm3atm2bpE+/bikvL1c4HFYwGFQwGFQ4HFZWVpZWr149LR8Ak7PowV0J24drSlNUCQDgbJFU+Ljhhhu0Y8cOVVZW6le/+pVyc3NVX1+vNWvWOGMqKio0NDSksrIyDQwMKD8/X62trfJ4PFNePAAASD8uY4xJdRGni8Vi8nq9ikajys7OTnU5ae3MVY3xYOUDADARyfz+5tkuAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqzJSXQBmh0UP7krYPlxTmqJKAAAzHSsfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrkgofoVBILpcrofn9fme/MUahUEiBQECZmZkqKipST0/PlBcNAADSV9IrH1//+td17Ngxp3V3dzv7amtrVVdXp4aGBnV2dsrv96u4uFiDg4NTWjQAAEhfSYePjIwM+f1+p1166aWSPl31qK+vV1VVlVatWqW8vDw1NTXp5MmTam5unvLCAQBAeko6fLz99tsKBALKzc3Vj370I73zzjuSpN7eXkUiEZWUlDhj3W63CgsL1dHR8bnHi8fjisViCQ0AAMxeSYWP/Px8PfPMM3rhhRf01FNPKRKJqKCgQB9++KEikYgkyefzJbzG5/M5+8ZSXV0tr9frtJycnAl8DAAAkC6SCh8rVqzQ97//fV177bW6/fbbtWvXLklSU1OTM8blciW8xhgzqu90lZWVikajTuvr60umJAAAkGYmdant+eefr2uvvVZvv/22c9XLmasc/f39o1ZDTud2u5WdnZ3QAADA7DWp8BGPx/X3v/9dCxYsUG5urvx+v9ra2pz9w8PDam9vV0FBwaQLBQAAs0NGMoN/8YtfaOXKlbrsssvU39+vRx99VLFYTOvWrZPL5VJ5ebnC4bCCwaCCwaDC4bCysrK0evXq6aofAACkmaTCx7vvvqu7775bH3zwgS699FLdeOON2rt3ry6//HJJUkVFhYaGhlRWVqaBgQHl5+ertbVVHo9nWooHAADpJ6nw0dLS8oX7XS6XQqGQQqHQZGoCAACzGM92AQAAVhE+AACAVYQPAABgVVLnfGDmWvTgrmk7zuGa0ik5NgAAEisfAADAMsIHAACwivABAACsInwAAACrOOE0DaT6JNAz358TUAEAk8HKBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKzi9upI2li3ewcAYLxY+QAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYlZHqAjA7LXpw16i+wzWlKagEADDTsPIBAACsmlT4qK6ulsvlUnl5udNnjFEoFFIgEFBmZqaKiorU09Mz2ToBAMAsMeHw0dnZqW3btmnx4sUJ/bW1taqrq1NDQ4M6Ozvl9/tVXFyswcHBSRcL
AADS34TCx0cffaQ1a9boqaee0kUXXeT0G2NUX1+vqqoqrVq1Snl5eWpqatLJkyfV3Nw85rHi8bhisVhCAwAAs9eEwsfGjRtVWlqq22+/PaG/t7dXkUhEJSUlTp/b7VZhYaE6OjrGPFZ1dbW8Xq/TcnJyJlLSWWfRg7sSGgAA6SLp8NHS0qL9+/erurp61L5IJCJJ8vl8Cf0+n8/Zd6bKykpFo1Gn9fX1JVsSAABII0ldatvX16fNmzertbVV55133ueOc7lcCdvGmFF9n3G73XK73cmUAQAA0lhSKx/79u1Tf3+/lixZooyMDGVkZKi9vV2PP/64MjIynBWPM1c5+vv7R62GAACAs1NS4eO2225Td3e3urq6nLZ06VKtWbNGXV1duuKKK+T3+9XW1ua8Znh4WO3t7SooKJjy4gEAQPpJ6msXj8ejvLy8hL7zzz9fF198sdNfXl6ucDisYDCoYDCocDisrKwsrV69euqqBgAAaWvKb69eUVGhoaEhlZWVaWBgQPn5+WptbZXH45nqtwIAAGnIZYwxqS7idLFYTF6vV9FoVNnZ2akuZ0aYLZfS8mwXAJi9kvn9zbNdAACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVRmpLgCjzZan2J7pzM/FU24B4OzEygcAALCK8AEAAKwifAAAAKsIHwAAwCpOOE2x2XpyKQAAn4eVDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFiVVPjYunWrFi9erOzsbGVnZ2vZsmV6/vnnnf3GGIVCIQUCAWVmZqqoqEg9PT1TXjQAAEhfSYWPhQsXqqamRm+88YbeeOMN3Xrrrfre977nBIza2lrV1dWpoaFBnZ2d8vv9Ki4u1uDg4LQUDwAA0k9S4WPlypX6zne+o6uuukpXXXWVfv3rX+uCCy7Q3r17ZYxRfX29qqqqtGrVKuXl5ampqUknT55Uc3PzdNUPAADSzITP+Th16pRaWlp04sQJLVu2TL29vYpEIiopKXHGuN1uFRYWqqOj43OPE4/HFYvFEhoAAJi9kg4f3d3duuCCC+R2u7Vhwwbt2LFD11xzjSKRiCTJ5/MljPf5fM6+sVRXV8vr9TotJycn2ZIAAEAaSTp8XH311erq6tLevXt13333ad26dXrrrbec/S6XK2G8MWZU3+kqKysVjUad1tfXl2xJAAAgjWQk+4K5c+fqyiuvlCQtXbpUnZ2deuyxx/TLX/5SkhSJRLRgwQJnfH9//6jVkNO53W653e5kywAAAGlq0vf5MMYoHo8rNzdXfr9fbW1tzr7h4WG1t7eroKBgsm8DAABmiaRWPh566CGtWLFCOTk5GhwcVEtLi1555RXt3r1bLpdL5eXlCofDCgaDCgaDCofDysrK0urVq6erfgAAkGaSCh/vv/++1q5dq2PHjsnr9Wrx4sXavXu3iouLJUkVFRUaGhpSWVmZBgYGlJ+fr9bWVnk8nmkpHgAApB+XMcakuojTxWIxeb1eRaNRZWdnp7qcabfowV2pLiFlDteUproEAMAUSeb3N892AQAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFiV9IPlMDln803FAACQWPkAAACWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFZxtQtmlPFcDXS4ptRCJQCA6cLKBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCKO5xOo/HcrRMAgLMNKx8AAMAqwgcAALCK8AEAAKwifAAAAKs44RQpwwm5
AHB2YuUDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgVVLho7q6WjfccIM8Ho/mz5+vO++8U4cOHUoYY4xRKBRSIBBQZmamioqK1NPTM6VFAwCA9JVU+Ghvb9fGjRu1d+9etbW16ZNPPlFJSYlOnDjhjKmtrVVdXZ0aGhrU2dkpv9+v4uJiDQ4OTnnxAAAg/WQkM3j37t0J242NjZo/f7727dun5cuXyxij+vp6VVVVadWqVZKkpqYm+Xw+NTc3a/369aOOGY/HFY/Hne1YLDaRzwEAANLEpM75iEajkqR58+ZJknp7exWJRFRSUuKMcbvdKiwsVEdHx5jHqK6ultfrdVpOTs5kSgIAADPchMOHMUZbtmzRzTffrLy8PElSJBKRJPl8voSxPp/P2XemyspKRaNRp/X19U20JAAAkAaS+trldJs2bdKbb76p1157bdQ+l8uVsG2MGdX3GbfbLbfbPdEyAABAmpnQysf999+vnTt36uWXX9bChQudfr/fL0mjVjn6+/tHrYYAAICzU1LhwxijTZs2afv27XrppZeUm5ubsD83N1d+v19tbW1O3/DwsNrb21VQUDA1FQMAgLSW1NcuGzduVHNzs/785z/L4/E4Kxxer1eZmZlyuVwqLy9XOBxWMBhUMBhUOBxWVlaWVq9ePS0fAGefRQ/uStg+XFOaokoAABORVPjYunWrJKmoqCihv7GxUffcc48kqaKiQkNDQyorK9PAwIDy8/PV2toqj8czJQUDAID0llT4MMZ86RiXy6VQKKRQKDTRmgAAwCzGs10AAIBVhA8AAGAV4QMAAFg14ZuMATMZV8QAwMzFygcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwijucTtCZd9CUuItmuuEuqACQGqx8AAAAqwgfAADAKsIHAACwivABAACs4oTTKTTWSaiYfvy5A0B6YeUDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAViUdPl599VWtXLlSgUBALpdLzz33XMJ+Y4xCoZACgYAyMzNVVFSknp6eqaoXAACkuaTDx4kTJ3TdddepoaFhzP21tbWqq6tTQ0ODOjs75ff7VVxcrMHBwUkXCwAA0l9Gsi9YsWKFVqxYMeY+Y4zq6+tVVVWlVatWSZKamprk8/nU3Nys9evXT65aAACQ9qb0nI/e3l5FIhGVlJQ4fW63W4WFhero6BjzNfF4XLFYLKEBAIDZK+mVjy8SiUQkST6fL6Hf5/PpyJEjY76murpajzzyyFSWAUyZRQ/uStg+XFOaokoAYPaYlqtdXC5XwrYxZlTfZyorKxWNRp3W19c3HSUBAIAZYkpXPvx+v6RPV0AWLFjg9Pf3949aDfmM2+2W2+2eyjIAAMAMNqUrH7m5ufL7/Wpra3P6hoeH1d7eroKCgql8KwAAkKaSXvn46KOP9K9//cvZ7u3tVVdXl+bNm6fLLrtM5eXlCofDCgaDCgaDCofDysrK0urVq6e0cAAAkJ6SDh9vvPGGbrnlFmd7y5YtkqR169bp6aefVkVFhYaGhlRWVqaBgQHl5+ertbVVHo9n6qpOgTNPPAQAABOTdPgoKiqSMeZz97tcLoVCIYVCocnUBQAAZime7QIAAKwifAAAAKsIHwAAwKopvc8HMFNxwjAAzBysfAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAq7jaBZikM6+kOVxTmqJKACA9sPIBAACsInwAAACrCB8AAMAqwgcAALCKE07HwK24z07jmffp/LvBiasAzhasfAAAAKsI
HwAAwCrCBwAAsIrwAQAArOKEUyCNcFIqgNmAlQ8AAGAV4QMAAFhF+AAAAFYRPgAAgFWccApMMZt3yB3rvcZzEionrgJIJVY+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVXO0iu1cnAOPF30sAsxUrHwAAwCrCBwAAsIrwAQAArCJ8AAAAqzjhFMC4TPRW7jaN5yTdmVYzcDZi5QMAAFhF+AAAAFYRPgAAgFWEDwAAYNVZd8Ipd43EbDdVf8fHc5ypOsHzzOOM96RQmz/PM/1k1nQ4IXg2SMc/55lYMysfAADAqmkLH0888YRyc3N13nnnacmSJfrrX/86XW8FAADSyLSEjz/96U8qLy9XVVWVDhw4oG9+85tasWKFjh49Oh1vBwAA0si0nPNRV1enn/70p/rZz34mSaqvr9cLL7ygrVu3qrq6OmFsPB5XPB53tqPRqCQpFotNR2kaiZ+cluMC6Wysn7ep+lkZz8/yme813p//idQ40X9bxvNe0/Xv1niMVV8q65mt0vHP2VbNnx3TGPPlg80Ui8fj5txzzzXbt29P6H/ggQfM8uXLR41/+OGHjSQajUaj0WizoPX19X1pVpjylY8PPvhAp06dks/nS+j3+XyKRCKjxldWVmrLli3O9sjIiP773//q4osvlsvlmuryMAmxWEw5OTnq6+tTdnZ2qsvB52Ce0gPzlB6Yp/EzxmhwcFCBQOBLx07bpbZnBgdjzJhhwu12y+12J/RdeOGF01UWpkB2djY/hGmAeUoPzFN6YJ7Gx+v1jmvclJ9weskll+jcc88dtcrR398/ajUEAACcfaY8fMydO1dLlixRW1tbQn9bW5sKCgqm+u0AAECamZavXbZs2aK1a9dq6dKlWrZsmbZt26ajR49qw4YN0/F2sMTtduvhhx8e9TUZZhbmKT0wT+mBeZoeLmPGc01M8p544gnV1tbq2LFjysvL0+9+9zstX758Ot4KAACkkWkLHwAAAGPh2S4AAMAqwgcAALCK8AEAAKwifAAAAKsIH2eZV199VStXrlQgEJDL5dJzzz2XsN8Yo1AopEAgoMzMTBUVFamnpydhTDwe1/33369LLrlE559/vr773e/q3XffTRgzMDCgtWvXyuv1yuv1au3atTp+/Pg0f7rZobq6WjfccIM8Ho/mz5+vO++8U4cOHUoYwzyl3tatW7V48WLnzpfLli3T888/7+xnjmam6upquVwulZeXO33MVQpM8jlySDN/+ctfTFVVlXn22WeNJLNjx46E/TU1Ncbj8Zhnn33WdHd3mx/+8IdmwYIFJhaLOWM2bNhgvvKVr5i2tjazf/9+c8stt5jrrrvOfPLJJ86Yb3/72yYvL890dHSYjo4Ok5eXZ+644w5bHzOtfetb3zKNjY3m4MGDpqury5SWlprLLrvMfPTRR84Y5in1du7caXbt2mUOHTpkDh06ZB566CEzZ84cc/DgQWMMczQTvf7662bRokVm8eLFZvPmzU4/c2Uf4eMsdmb4GBkZMX6/39TU1Dh9H3/8sfF6vebJJ580xhhz/PhxM2fOHNPS0uKMee+998w555xjdu/ebYwx5q233jKSzN69e50xe/bsMZLMP/7xj2n+VLNPf3+/kWTa29uNMczTTHbRRReZ3//+98zRDDQ4OGiCwaBpa2szhYWFTvhgrlKDr13g6O3tVSQSUUlJidPndrtVWFiojo4OSdK+ffv0v//9L2FMIBBQXl6eM2bPnj3yer3Kz893xtx4443yer3OGIxfNBqVJM2bN08S8zQTnTp1Si0tLTpx4oSWLVvGHM1AGzduVGlpqW6//faEfuYqNabtqbZIP589DPDMBwD6fD4dOXLEGTN37lxddNFFo8Z89vpIJKL58+ePOv78+fNHPXAQX8wYoy1btujmm29WXl6eJOZpJunu7tayZcv08ccf64ILLtCOHTt0zTXXOL9smKOZoaWlRfv371dnZ+eoffw8pQbhA6O4XK6EbWPMqL4z
nTlmrPHjOQ4Sbdq0SW+++aZee+21UfuYp9S7+uqr1dXVpePHj+vZZ5/VunXr1N7e7uxnjlKvr69PmzdvVmtrq84777zPHcdc2cXXLnD4/X5JGpXS+/v7nf8r8Pv9Gh4e1sDAwBeOef/990cd/z//+c+o/7vA57v//vu1c+dOvfzyy1q4cKHTzzzNHHPnztWVV16ppUuXqrq6Wtddd50ee+wx5mgG2bdvn/r7+7VkyRJlZGQoIyND7e3tevzxx5WRkeH8OTJXdhE+4MjNzZXf71dbW5vTNzw8rPb2dhUUFEiSlixZojlz5iSMOXbsmA4ePOiMWbZsmaLRqF5//XVnzN/+9jdFo1FnDD6fMUabNm3S9u3b9dJLLyk3NzdhP/M0cxljFI/HmaMZ5LbbblN3d7e6urqctnTpUq1Zs0ZdXV264oormKtUSMFJrkihwcFBc+DAAXPgwAEjydTV1ZkDBw6YI0eOGGM+veTM6/Wa7du3m+7ubnP33XePecnZwoULzYsvvmj2799vbr311jEvOVu8eLHZs2eP2bNnj7n22mu55Gyc7rvvPuP1es0rr7xijh075rSTJ086Y5in1KusrDSvvvqq6e3tNW+++aZ56KGHzDnnnGNaW1uNMczRTHb61S7GMFepQPg4y7z88stG0qi2bt06Y8ynl509/PDDxu/3G7fbbZYvX266u7sTjjE0NGQ2bdpk5s2bZzIzM80dd9xhjh49mjDmww8/NGvWrDEej8d4PB6zZs0aMzAwYOlTprex5keSaWxsdMYwT6l37733mssvv9zMnTvXXHrppea2225zgocxzNFMdmb4YK7scxljTGrWXAAAwNmIcz4AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABY9X9b/ni7JlIFigAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "最大值： 4503 最小值： 245\n"
     ]
    }
   ],
   "source": [
    "# Histogram of cluster sizes: how many clusters contain a given number of vectors.\n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(statis_result, bins=100)\n",
    "ax.set(title=\"Cluster size distribution\", xlabel=\"vectors per cluster\", ylabel=\"number of clusters\")\n",
    "plt.show()\n",
    "\n",
    "# Largest, smallest and mean cluster size\n",
    "print(\"最大值：\", max_value, \"最小值：\", min_value, \"平均值：\", np.mean(statis_result))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Nearest-cluster id of every dataset vector: ids[i] is the cluster of dataset[i].\n",
    "ids = get_cluster_id_for_vectors(vectors=dataset, quantizer=quantizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Group vector ids by cluster: clustered_dataset[c] lists, in ascending order, the ids of\n",
    "# the dataset vectors assigned to cluster c.\n",
    "clustered_dataset = [[] for _ in range(cluster_count)]\n",
    "for vector_id, cluster_id in enumerate(ids):\n",
    "    clustered_dataset[cluster_id].append(vector_id)\n",
    "\n",
    "# Number of vectors in each cluster, as a NumPy array\n",
    "clustered_dataset_statis = np.array([len(members) for members in clustered_dataset])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[      0    1057    2348 ...  997897  999118 1000000]\n"
     ]
    }
   ],
   "source": [
    "# Offset array: offset_list[c] is the position (counted in vectors) where cluster c starts in\n",
    "# last_layer.bin; the final entry is the total number of vectors.\n",
    "offset_list = np.zeros(shape=(cluster_count + 1,), dtype=np.int32)\n",
    "offset_list[1:] = np.cumsum(clustered_dataset_statis)  # prefix sums in O(k) instead of re-summing per cluster\n",
    "assert offset_list[-1] == dataset.shape[0], \"cluster sizes do not add up to the dataset size\"\n",
    "print(offset_list)\n",
    "\n",
    "if written:\n",
    "    # Offset file in the .bin layout: rows = cluster_count + 1, cols = 1, then int32 offsets.\n",
    "    with open(offset_list_path, 'wb') as f:\n",
    "        f.write(struct.pack('<ii', offset_list.shape[0], 1))\n",
    "        f.write(offset_list.astype('<i4').tobytes())\n",
    "\n",
    "    # Clustered vectors, cluster by cluster. Each record is a 4-byte int32 vector id (the row\n",
    "    # in dataset, needed because clustering reorders the vectors) followed by the raw vector.\n",
    "    record_dtype = np.dtype([('id', '<i4'), ('vec', dataset.dtype, (dataset.shape[1],))])\n",
    "    with open(last_layer_path, 'wb') as f:\n",
    "        for members in clustered_dataset:\n",
    "            # One packed write per cluster instead of two small writes per vector.\n",
    "            records = np.empty(len(members), dtype=record_dtype)\n",
    "            records['id'] = members\n",
    "            records['vec'] = dataset[members]\n",
    "            f.write(records.tobytes())\n",
    "else:\n",
    "    print(\"Not write offset list and last layer\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
